{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04204.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/04204.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sXxfBYJA9yvx",
        "colab_type": "text"
      },
      "source": [
        "## 4.4 自定义层"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vU322NBV-JG0",
        "colab_type": "text"
      },
      "source": [
        "### 4.4.1 不含模型参数的自定义层"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KTjHqN51-Qh-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "from torch import nn \n",
        "\n",
        "class CenteredLayer(nn.Module):\n",
        "    def __init__(self, **kwargs):\n",
        "        super(CenteredLayer, self).__init__(**kwargs)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return x - x.mean()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Mfe9scRF-lo-",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "06648c30-7d55-4162-be9e-464b8c6f0118"
      },
      "source": [
        "layer = CenteredLayer()\n",
        "layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([-2., -1.,  0.,  1.,  2.])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Joi_JK7J-ywG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fN-D95YZ-75t",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "35893931-ecf4-492c-a836-825fbee1f5c1"
      },
      "source": [
        "y = net(torch.rand(4, 8))\n",
        "y.mean().item()"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "4.656612873077393e-10"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cpACocwi_GKu",
        "colab_type": "text"
      },
      "source": [
        "### 4.4.2 含模型参数的自定义层"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "D7QQaQPO_Nvl",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 151
        },
        "outputId": "d21658c9-f0bc-49ef-c50d-0da9b0772cb1"
      },
      "source": [
        "class MyListDense(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(MyListDense, self).__init__()\n",
        "        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])\n",
        "        self.params.append(nn.Parameter(torch.randn(4, 1)))\n",
        "\n",
        "    def forward(self, x):\n",
        "        for i in range(len(self.params)):\n",
        "            x = torch.mm(x, self.params[i])\n",
        "        return x \n",
        "\n",
        "net = MyListDense()\n",
        "net"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "MyListDense(\n",
              "  (params): ParameterList(\n",
              "      (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
              "      (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
              "      (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
              "      (3): Parameter containing: [torch.FloatTensor of size 4x1]\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FPZQsp7Z_5X5",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 134
        },
        "outputId": "35d20e0d-7fc4-4aa6-b12c-3fee142a7dde"
      },
      "source": [
        "class MyDictDense(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(MyDictDense, self).__init__()\n",
        "        self.params = nn.ParameterDict({\n",
        "            'linear1': nn.Parameter(torch.randn(4, 4)), \n",
        "            'linear2': nn.Parameter(torch.randn(4, 1)),\n",
        "        })\n",
        "        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})\n",
        "\n",
        "    def forward(self, x, choice='linear1'):\n",
        "        return torch.mm(x, self.params[choice])\n",
        "\n",
        "net = MyDictDense()\n",
        "net"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "MyDictDense(\n",
              "  (params): ParameterDict(\n",
              "      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
              "      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
              "      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0bHS4yGuA1EK",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "73e66eeb-0aeb-4bde-fa17-f0d118fed098"
      },
      "source": [
        "x = torch.ones(1, 4)\n",
        "net(x, 'linear1')"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[ 0.0299, -0.1434, -1.7719,  2.3091]], grad_fn=<MmBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TcsuCkZ-BBCh",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "af9f3c9c-ee5a-42b9-c966-78fe1f25699f"
      },
      "source": [
        "net(x, 'linear2')"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[-1.7349]], grad_fn=<MmBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2JIKYUeOBDD6",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "1879a994-8725-4a46-8a5b-dd401aa90188"
      },
      "source": [
        "net(x, 'linear3')"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[-1.9762,  1.9643]], grad_fn=<MmBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8G6MH7XQBFW4",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 302
        },
        "outputId": "75732164-7f32-465b-8860-ab60a78a3319"
      },
      "source": [
        "net = nn.Sequential(\n",
        "    MyDictDense(), \n",
        "    MyListDense(), \n",
        ")\n",
        "\n",
        "net "
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Sequential(\n",
              "  (0): MyDictDense(\n",
              "    (params): ParameterDict(\n",
              "        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
              "        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]\n",
              "        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]\n",
              "    )\n",
              "  )\n",
              "  (1): MyListDense(\n",
              "    (params): ParameterList(\n",
              "        (0): Parameter containing: [torch.FloatTensor of size 4x4]\n",
              "        (1): Parameter containing: [torch.FloatTensor of size 4x4]\n",
              "        (2): Parameter containing: [torch.FloatTensor of size 4x4]\n",
              "        (3): Parameter containing: [torch.FloatTensor of size 4x1]\n",
              "    )\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7t2eNCvZBQQ5",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "14645358-fdfe-4670-ae0b-6c0cfc9fbe31"
      },
      "source": [
        "net(x)"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[-19.8875]], grad_fn=<MmBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 19
        }
      ]
    }
  ]
}